import java.util.*;

/**
 * Reads four integers {@code n m a b} from standard input and prints the
 * largest value of {@code a * x + b * y} found by an exhaustive scan over
 * {@code x}, with {@code y} chosen greedily for each {@code x}.
 */
public class Main {
    /**
     * Program entry point: parses the four whitespace-separated integers
     * (in the order n, m, a, b) and prints the best total on its own line.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        Scanner input = new Scanner(System.in);
        // Values are parsed as ints but widened to long so the products below
        // are computed in 64-bit arithmetic.
        long n = input.nextInt();
        long m = input.nextInt();
        long a = input.nextInt();
        long b = input.nextInt();

        System.out.println(bestTotal(n, m, a, b));
    }

    /**
     * Scans every {@code x} in {@code [0, min(n / 2, m)]}, pairs it with
     * {@code y = min(n - 2 * x, (m - x) / 2)}, and returns the maximum of
     * {@code a * x + b * y} seen.
     *
     * @param n first quantity limit
     * @param m second quantity limit
     * @param a weight applied to {@code x}
     * @param b weight applied to {@code y}
     * @return the largest total found; never below 0 because the running
     *         maximum starts at 0 (also the result when the range is empty)
     */
    private static long bestTotal(long n, long m, long a, long b) {
        long best = 0;
        // Upper bound keeps both n - 2 * x and m - x non-negative.
        long lastX = Math.min(n / 2, m);
        long x = 0;
        while (x <= lastX) {
            long y = Math.min(n - 2 * x, (m - x) / 2);
            long total = a * x + b * y;
            if (total > best) {
                best = total;
            }
            x++;
        }
        return best;
    }
}